from functools import partial

from .activations import *


def get_act_layer(act_layer: str, **kwargs):
    """Return a constructor for the activation identified by ``act_layer``.

    The name is matched case-insensitively and with surrounding whitespace
    ignored, so ``"relu"``, ``"ReLU"`` and ``" relu "`` are equivalent.

    Recognised names (aliases separated by ``/``):
    ``relu``, ``lrelu``/``leaky_relu``, ``prelu``, ``maxout``, ``gelu``,
    ``hgelu``/``gclu``, ``quick_gclu``, ``gclu_tanh``, ``seqhgelu``,
    ``adrelu``/``drelu``, ``dyrelub``.

    Args:
        act_layer: Name of the activation.
        **kwargs: Optional settings, only consulted by some activations:
            ``negative_slope`` (leaky ReLU, default 0.01),
            ``num_parameters`` and ``init`` (PReLU, defaults 1 and 0.25),
            ``num_units`` (maxout, default 2). Any other keys, and all keys
            for the remaining activations, are ignored.

    Returns:
        A callable that builds the activation: a ``functools.partial`` with
        the settings pre-bound for relu/leaky-relu/prelu/maxout, otherwise
        the activation object itself as exposed by the activations module.

    Raises:
        NotImplementedError: If ``act_layer`` is not a recognised name. The
            original, un-normalised value is carried as the exception
            argument.
    """
    # Normalise only real strings; anything else falls through to the
    # NotImplementedError below rather than failing on a missing method.
    if isinstance(act_layer, str):
        name = act_layer.strip().lower()
    else:
        name = act_layer

    if name == "relu":
        return partial(nn.ReLU, inplace=True)
    if name in ("lrelu", "leaky_relu"):
        return partial(
            nn.LeakyReLU,
            negative_slope=kwargs.get("negative_slope", 0.01),
            inplace=True,
        )
    if name == "prelu":
        return partial(
            nn.PReLU,
            num_parameters=kwargs.get("num_parameters", 1),
            init=kwargs.get("init", 0.25),
        )
    if name == "maxout":
        # Default of 2 linear pieces.
        return partial(Maxout, num_units=kwargs.get("num_units", 2))
    if name == "gelu":
        return GELU
    if name in ("hgelu", "gclu"):
        return HGELU
    if name == "quick_gclu":
        return QuickGCLU
    if name == "gclu_tanh":
        return GCLUTanh
    if name == "seqhgelu":
        # NOTE(review): identifier spelling kept as-is; it must match the
        # name defined outside this function.
        return SequecialHGELU
    if name in ("adrelu", "drelu"):
        return ADReLU
    if name == "dyrelub":
        # The historical mixed-case key "DyReLUB" lower-cases to this.
        return DyReLUB

    raise NotImplementedError(act_layer)

